In [5]:
import os
import numpy as np
import pandas as pd
import scanpy as sc
import glob
import matplotlib.pyplot as plt
import seaborn as sns
import math

Data Preprocessing:¶

In [6]:
adata = sc.read_h5ad("/vast/palmer/pi/xiting_yan/hw568/collections_spatial_datasets/spatialDLPFC_new/adata_vis_after.h5ad")
In [7]:
image = adata.uns['spatial']['Br6522_ant']['images']['hires']
row = adata.obsm['spatial'][adata.obs['sample_id'] == 'Br6522_ant', 0] * adata.uns['spatial']['Br6522_ant']['scalefactors']['tissue_hires_scalef']
col = adata.obsm['spatial'][adata.obs['sample_id'] == 'Br6522_ant', 1] * adata.uns['spatial']['Br6522_ant']['scalefactors']['tissue_hires_scalef']

plt.figure(figsize=(6, 6))
plt.imshow(image)
plt.scatter(row, col, color='red', s=1)
plt.show()
No description has been provided for this image
In [8]:
spot_diameter_fullres = adata.uns['spatial']['Br6522_ant']['scalefactors']['spot_diameter_fullres']
spot_radius_full_res = spot_diameter_fullres/2
tissue_hires_scalef = adata.uns['spatial']['Br6522_ant']['scalefactors']['tissue_hires_scalef']
spot_radius_hires = spot_radius_full_res * tissue_hires_scalef

print(f"The radius of spot in high resolution image is {spot_radius_hires:.2f} pixels.")
The radius of spot in high resolution image is 5.77 pixels.
In [9]:
sample_id = "Br6522_ant"

# Extract spot coordinates
spots_coords = adata.obsm['spatial'][adata.obs['sample_id'] == sample_id]
spot_radius = 5.77

# Count the total number of spots
num_spots = len(spots_coords)

# Display the total number of spots
print(f"Sample ID: {sample_id}")
print(f"Total Number of Spots: {num_spots}")
Sample ID: Br6522_ant
Total Number of Spots: 4263
In [10]:
print(f"Image dimensions: Width = {image.shape[1]}, Height = {image.shape[0]}")
Image dimensions: Width = 1658, Height = 2000

This is clearly out of bounds, and needs to be scaled to match the pixel dimensions of the image.

In [11]:
scaling_factor_x = 1658 / np.max(spots_coords[:, 0])
scaling_factor_y = 2000 / np.max(spots_coords[:, 1])
spots_coords_scaled = spots_coords * [scaling_factor_x, scaling_factor_y]
print(f"Scaled Spots Range: Min X = {np.min(spots_coords_scaled[:, 0])}, Max X = {np.max(spots_coords_scaled[:, 0])}")
print(f"Scaled Spots Range: Min Y = {np.min(spots_coords_scaled[:, 1])}, Max Y = {np.max(spots_coords_scaled[:, 1])}")
Scaled Spots Range: Min X = 253.15847449782467, Max X = 1657.9999999999998
Scaled Spots Range: Min Y = 558.8632734122272, Max Y = 2000.0

Feature Extraction:¶

VIT part¶

In [12]:
from PIL import Image
import numpy as np
from transformers import ViTModel, ViTImageProcessor
import torch

model_name = "google/vit-base-patch16-224-in21k"
model = ViTModel.from_pretrained(model_name)  # Pre-trained Vision Transformer
processor = ViTImageProcessor.from_pretrained(model_name)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model.eval()

# Extract features for selected spots
vector_representations = {}

with torch.no_grad():
    for idx, (x, y) in enumerate(spots_coords_scaled):
        start_x = max(0, int(x - spot_radius))
        end_x = min(image.shape[1], int(x + spot_radius))
        start_y = max(0, int(y - spot_radius))
        end_y = min(image.shape[0], int(y + spot_radius))
        
        if start_x >= end_x or start_y >= end_y:
            print(f"Skipping spot {idx+1}: Invalid crop boundaries.")
            continue
        
        # Crop the image
        cropped_image = image[start_y:end_y, start_x:end_x]
        cropped_pil = Image.fromarray((cropped_image * 255).astype(np.uint8))
        
        # Preprocess the image for ViT
        inputs = processor(images=cropped_pil, return_tensors="pt", size=(224, 224))
        inputs = {key: val.to(device) for key, val in inputs.items()}
        
        # Forward pass through ViT
        outputs = model(**inputs)
        cls_embedding = outputs.last_hidden_state[:, 0, :].squeeze(0).cpu().numpy()
        
        # Store the embedding
        vector_representations[f"spot_{idx+1}"] = cls_embedding

Clustering¶

In [13]:
from sklearn.decomposition import PCA

vectors = np.array(list(vector_representations.values()))

pca = PCA(n_components=2)
pcs = pca.fit_transform(vectors)

pv1 = pca.components_[0]
pv2 = pca.components_[1]

# Store the principal components in a DataFrame
principalX = pd.DataFrame(data=pcs, columns=['PC1', 'PC2'])

principalX.head(10)
Out[13]:
PC1 PC2
0 -2.300600 1.023820
1 0.352776 -1.670902
2 -2.373689 1.021301
3 3.593608 0.635154
4 3.587902 0.634732
5 2.577854 -0.348189
6 3.548557 0.559539
7 0.646632 -2.937982
8 -2.120048 0.328270
9 -2.395325 0.715087
In [14]:
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

vectors = np.array(list(vector_representations.values()))

pca = PCA(n_components=2)
reduced_vectors = pca.fit_transform(vectors)

plt.figure(figsize=(8, 6))
plt.scatter(reduced_vectors[:, 0], reduced_vectors[:, 1], c='blue', edgecolor='k', s=60)
plt.title("PCA of Spot Embeddings", fontsize=16)
plt.xlabel("Principal Component 1", fontsize=12)
plt.ylabel("Principal Component 2", fontsize=12)
plt.grid(True)
plt.show()
No description has been provided for this image

There are 4 types of Manual Annotation, so we set K=4.

K Means Clustering¶

In [16]:
from sklearn.cluster import KMeans

kmeans = KMeans(n_clusters=4, random_state=42) 
clusters = kmeans.fit_predict(reduced_vectors)

plt.figure(figsize=(8, 6))
plt.scatter(reduced_vectors[:, 0], reduced_vectors[:, 1], c=clusters, cmap='viridis', edgecolor='k', s=10)
plt.colorbar(label="Cluster Label")
plt.title("PCA of Spot Embeddings with K-Means Clusters", fontsize=16)
plt.xlabel("Principal Component 1", fontsize=12)
plt.ylabel("Principal Component 2", fontsize=12)
plt.grid(True)
plt.show()
No description has been provided for this image
In [32]:
from umap import UMAP

umap = UMAP(n_components=2, random_state=42)
umap_results = umap.fit_transform(vectors)  # Use the original vectors before PCA

# Visualize the UMAP results
plt.figure(figsize=(8, 6))
plt.scatter(umap_results[:, 0], umap_results[:, 1], c=clusters, cmap='viridis', edgecolor='k', s=10)
plt.colorbar(label="Cluster Label")
plt.title("UMAP of Spot Embeddings", fontsize=16)
plt.xlabel("UMAP Component 1", fontsize=12)
plt.ylabel("UMAP Component 2", fontsize=12)
plt.grid(True)
plt.show
/home/ll2276/.conda/envs/new_env/lib/python3.12/site-packages/umap/umap_.py:1952: UserWarning: n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.
  warn(
Out[32]:
<function matplotlib.pyplot.show(close=None, block=None)>
No description has been provided for this image
In [30]:
# File path
file_path = "/vast/palmer/pi/xiting_yan/hw568/collections_spatial_datasets/spatialDLPFC_new/05-shared_utilities/nonIF/spatialLIBD_ManualAnnotation_Br6522_ant_wrinkle.csv"

data = pd.read_csv(file_path)
print(data.head(50))
annotation_counts = data['ManualAnnotation'].value_counts()
print(data.shape[0])
print(annotation_counts)
     sample_id           spot_name ManualAnnotation
0   Br6522_ant  AAACGCTGGGCACGAC-1           Fold_1
1   Br6522_ant  AAACGGGCGTACGGGT-1        Wrinkle_8
2   Br6522_ant  AAACTAACGTGGCGAC-1          Shear_3
3   Br6522_ant  AAAGGCTCTCGCGCCG-1        Wrinkle_5
4   Br6522_ant  AAATAGGGTGCTATTG-1        Wrinkle_5
5   Br6522_ant  AACATAGCGTGTATCG-1          Shear_1
6   Br6522_ant  AACCATGGGATCGCTA-1        Wrinkle_1
7   Br6522_ant  AACCCGAGCAGAATCG-1        Wrinkle_7
8   Br6522_ant  AACCTTTAAATACGGT-1        Wrinkle_5
9   Br6522_ant  AACGAAAGTCGTCCCA-1        Wrinkle_6
10  Br6522_ant  AACGACCTCCTAGCCG-1           Fold_1
11  Br6522_ant  AACGTCGCTGCACTTC-1           Fold_1
12  Br6522_ant  AACTAGGCTTGGGTGT-1        Wrinkle_5
13  Br6522_ant  AACTTGCGTTCTCGCG-1        Wrinkle_1
14  Br6522_ant  AAGAAAGTTTGATGGG-1          Shear_2
15  Br6522_ant  AAGCGCAGGGCTTTGA-1        Wrinkle_8
16  Br6522_ant  AAGCGGCGTCATGGGT-1          Shear_3
17  Br6522_ant  AAGCGTCCCTCATCGA-1        Wrinkle_5
18  Br6522_ant  AAGCTATGGATTGACC-1          Shear_2
19  Br6522_ant  AAGGGTTTGATTTCAG-1        Wrinkle_6
20  Br6522_ant  AAGTAGAAGACCGGGT-1        Wrinkle_2
21  Br6522_ant  AAGTTTATGGGCCCAA-1        Wrinkle_8
22  Br6522_ant  AATACCTGATGTGAAC-1           Fold_1
23  Br6522_ant  AATAGGCACGACCCTT-1          Shear_2
24  Br6522_ant  AATAGTCCGTCCCGAC-1        Wrinkle_5
25  Br6522_ant  AATCCCGCTCAGAGCC-1          Shear_1
26  Br6522_ant  AATCGAGGTCTCAAGG-1          Shear_1
27  Br6522_ant  AATCGCCTCAGCGCCA-1        Wrinkle_4
28  Br6522_ant  AATCGGTATAGCCCTC-1           Fold_1
29  Br6522_ant  AATCGTGAGCCGAGCA-1        Wrinkle_8
30  Br6522_ant  AATGAGTTCGCATATG-1        Wrinkle_1
31  Br6522_ant  AATGTTAAGACCCTGA-1          Shear_1
32  Br6522_ant  AATTACGAGACCCATC-1        Wrinkle_6
33  Br6522_ant  AATTATACCCAGCAAG-1          Shear_1
34  Br6522_ant  AATTGCAGCAATCGAC-1        Wrinkle_8
35  Br6522_ant  ACAAGTAATTGTAAGG-1          Shear_2
36  Br6522_ant  ACAATGATTCTTCTAC-1          Shear_3
37  Br6522_ant  ACAATTGTGTCTCTTT-1        Wrinkle_5
38  Br6522_ant  ACACCCGAGAAATCCG-1        Wrinkle_2
39  Br6522_ant  ACAGGCTTGCCCGACT-1        Wrinkle_5
40  Br6522_ant  ACATAAGTCGTGGTGA-1        Wrinkle_6
41  Br6522_ant  ACATACAATCAAGCGG-1          Shear_1
42  Br6522_ant  ACATCCCGGCCATACG-1        Wrinkle_5
43  Br6522_ant  ACATCGCAATATTCGG-1          Shear_2
44  Br6522_ant  ACCAACCGCACTCCAC-1        Wrinkle_8
45  Br6522_ant  ACCACACGGTTGATGG-1          Shear_1
46  Br6522_ant  ACCAGTGCCCGGTCAA-1          Shear_3
47  Br6522_ant  ACCATCGTATATGGTA-1        Wrinkle_4
48  Br6522_ant  ACCCGGATGACGCATC-1        Wrinkle_4
49  Br6522_ant  ACCCGGTTACACTTCC-1        Wrinkle_4
547
ManualAnnotation
Wrinkle_5    90
Wrinkle_8    79
Shear_1      73
Wrinkle_4    72
Shear_2      56
Fold_1       49
Wrinkle_1    38
Wrinkle_2    26
Shear_3      24
Wrinkle_3    16
Wrinkle_6    11
Wrinkle_7    10
Wrinkle_9     3
Name: count, dtype: int64
In [24]:
from sklearn.preprocessing import LabelEncoder

# Encode categorical features
data['sample_id'] = LabelEncoder().fit_transform(data['sample_id'])
data['spot_name'] = LabelEncoder().fit_transform(data['spot_name'])
labels = LabelEncoder().fit_transform(data['ManualAnnotation'])

# Extract features for UMAP
features = data[['sample_id', 'spot_name']]

# Apply UMAP
reducer = umap.UMAP(n_components=2, random_state=42)
umap_results = reducer.fit_transform(features)

# Plot UMAP results
plt.figure(figsize=(10, 8))
scatter = plt.scatter(
    umap_results[:, 0],
    umap_results[:, 1],
    c=labels,
    cmap="Spectral",
    edgecolor="k",
    s=50,
)

plt.colorbar(scatter, label="Class Labels (ManualAnnotation)")
plt.title("UMAP Visualization of ManualAnnotation", fontsize=16)
plt.xlabel("UMAP Component 1", fontsize=12)
plt.ylabel("UMAP Component 2", fontsize=12)
plt.grid(True)
plt.show()
/home/ll2276/.conda/envs/new_env/lib/python3.12/site-packages/umap/umap_.py:1952: UserWarning: n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.
  warn(
No description has been provided for this image
In [9]:
Br6522_mid_image = adata.uns['spatial']['Br6522_mid']['images']['hires']
row_mid = adata.obsm['spatial'][adata.obs['sample_id'] == 'Br6522_mid', 0] * adata.uns['spatial']['Br6522_mid']['scalefactors']['tissue_hires_scalef']
col_mid = adata.obsm['spatial'][adata.obs['sample_id'] == 'Br6522_mid', 1] * adata.uns['spatial']['Br6522_mid']['scalefactors']['tissue_hires_scalef']

plt.figure(figsize=(6, 6))
plt.imshow(Br6522_mid_image)
plt.scatter(row_mid, col_mid, color='red', s=1)
plt.show()
No description has been provided for this image
In [9]:
Br6522_mid_image = adata.uns['spatial']['Br6522_mid']['images']['hires']
row_mid = adata.obsm['spatial'][adata.obs['sample_id'] == 'Br6522_mid', 0] * adata.uns['spatial']['Br6522_mid']['scalefactors']['tissue_hires_scalef']
col_mid = adata.obsm['spatial'][adata.obs['sample_id'] == 'Br6522_mid', 1] * adata.uns['spatial']['Br6522_mid']['scalefactors']['tissue_hires_scalef']

plt.figure(figsize=(6, 6))
plt.imshow(Br6522_mid_image)
plt.scatter(row_mid, col_mid, color='red', s=1)
plt.show()
No description has been provided for this image
In [9]:
Br6522_mid_image = adata.uns['spatial']['Br6522_mid']['images']['hires']
row_mid = adata.obsm['spatial'][adata.obs['sample_id'] == 'Br6522_mid', 0] * adata.uns['spatial']['Br6522_mid']['scalefactors']['tissue_hires_scalef']
col_mid = adata.obsm['spatial'][adata.obs['sample_id'] == 'Br6522_mid', 1] * adata.uns['spatial']['Br6522_mid']['scalefactors']['tissue_hires_scalef']

plt.figure(figsize=(6, 6))
plt.imshow(Br6522_mid_image)
plt.scatter(row_mid, col_mid, color='red', s=1)
plt.show()
No description has been provided for this image
In [10]:
Br8667_post_image = adata.uns['spatial']['Br8667_post']['images']['hires']
row_post = adata.obsm['spatial'][adata.obs['sample_id'] == 'Br8667_post', 0] * adata.uns['spatial']['Br8667_post']['scalefactors']['tissue_hires_scalef']
col_post = adata.obsm['spatial'][adata.obs['sample_id'] == 'Br8667_post', 1] * adata.uns['spatial']['Br8667_post']['scalefactors']['tissue_hires_scalef']

plt.figure(figsize=(6, 6))
plt.imshow(Br8667_post_image)
plt.scatter(row_post, col_post, color='red', s=1)
plt.show()
No description has been provided for this image
In [3]:
# Get all sample IDs
sample_ids = list(adata.uns['spatial'].keys())

# Define the grid size
grid_size = 8
num_samples = len(sample_ids)
num_rows = math.ceil(num_samples / grid_size)

# Create a figure for the grid
fig, axes = plt.subplots(num_rows, grid_size, figsize=(grid_size * 2, num_rows * 2))

# Flatten the axes for easier indexing
axes = axes.flatten()

# Iterate over each sample and plot
for i, sample_id in enumerate(sample_ids):
    # Access the image and spatial data
    sample_data = adata.uns['spatial'][sample_id]
    image = sample_data['images']['hires']
    
    row = (
        adata.obsm['spatial'][adata.obs['sample_id'] == sample_id, 0]
        * sample_data['scalefactors']['tissue_hires_scalef']
    )
    col = (
        adata.obsm['spatial'][adata.obs['sample_id'] == sample_id, 1]
        * sample_data['scalefactors']['tissue_hires_scalef']
    )
    
    # Plot the image and points in the corresponding subplot
    ax = axes[i]
    ax.imshow(image)
    ax.set_title(f"Sample: {sample_id}", fontsize=8)
    ax.axis("off")

# Turn off any unused subplots
for j in range(num_samples, len(axes)):
    axes[j].axis("off")

# Adjust layout and display
plt.tight_layout()
plt.show()
No description has been provided for this image